from qubiter.adv_applications.StairsDeriv import *
from qubiter.adv_applications.MeanHamil import *
from qubiter.device_specific.Qubiter_to_RigettiPyQuil import *
from qubiter.device_specific.RigettiTools import *
import qubiter.utilities_gen as utg
from qubiter.CGateExpander import *

import itertools as it
import copy as cp

from openfermion.ops import QubitOperator

from pyquil.quil import Program
from pyquil.api import QVMConnection
from pyquil.gates import *
from pyquil import get_qc
from pyquil.api import WavefunctionSimulator
# from pyquil.reference_simulator import ReferenceWavefunctionSimulator


class StairsDeriv_rigetti(StairsDeriv):
    """
    This class is a child of StairsDeriv. Its main purpose is to override
    the method get_mean_val() of its abstract parent class StairsDeriv. In
    this class, the simulation necessary to evaluate the output of
    get_mean_val() is done either by a Rigetti PyQuil wavefunction
    simulator or by the device wrapped by qc.

    Attributes
    ----------
    qc : QuantumComputer
        returned by PyQuil method get_qc(). It is only used when
        self.num_samples is nonzero.
    translation_line_list : list[str]
        a list of lines of PyQuil code generated by the translator. The
        lines all start with "pg +=". Overwritten every time a new
        derivative-part circuit is translated.
    translator : Qubiter_to_RigettiPyQuil
        the translator most recently created by get_mean_val(); None until
        get_mean_val() is first called.

    """
    def __init__(self, qc, deriv_gate_str, gate_str_to_rads_list,
                 file_prefix, parent_num_qbits, hamil, **kwargs):
        """
        Constructor

        Parameters
        ----------
        qc : QuantumComputer
        deriv_gate_str : str
        gate_str_to_rads_list : dict[str, list[float|str]]
        file_prefix : str
        parent_num_qbits : int
        hamil : QubitOperator
        kwargs : dict
            key-word arguments of MeanHamil

        Returns
        -------

        """
        StairsDeriv.__init__(self, deriv_gate_str, gate_str_to_rads_list,
                 file_prefix, parent_num_qbits, hamil, **kwargs)

        self.qc = qc
        self.translator = None
        self.translation_line_list = []

    def _get_program_from_file(self, num_qbits_w_anc, vman):
        """
        Expands the circuit stored under self.file_prefix into CNOTs and
        single qubit gates, translates the expanded circuit into PyQuil
        code, and executes that code to build a PyQuil Program.

        Side effects: sets self.translator and self.translation_line_list.

        Parameters
        ----------
        num_qbits_w_anc : int
            number of qubits, including the ancilla
        vman : PlaceholderManager
            used to replace placeholder variable strings by floats

        Returns
        -------
        Program

        """
        # CGateExpander and the translator Qubiter_to_RigettiPyQuil
        # are both children of SEO_reader. SEO_reader and any of its
        # subclasses will accept a vman (object of PlaceholderManager) in
        # one of its keyword args. If a SEO_reader is given a vman as
        # input, it will use it to replace placeholder variable strings
        # by floats.

        # PyQuil does not support multi-controlled u2 gates so expand them
        # to lowest common denominator, CNOTs and single qubit gates, using
        # CGateExpander. Give CGateExpander a vman input so as to float all
        # variables before expansion. The expander object itself is not
        # needed afterwards; only the expanded file it produces is.
        CGateExpander(self.file_prefix, num_qbits_w_anc, vars_manager=vman)
        # this gives name of new file with expansion
        out_file_prefix = SEO_reader.xed_file_prefix(self.file_prefix)

        # this creates a file with all PyQuil gates that are
        # independent of hamil.
        self.translator = Qubiter_to_RigettiPyQuil(
            out_file_prefix, self.num_qbits,
            aqasm_name='RigPyQuil', prelude_str='', ending_str='')
        with open(utg.preface(self.translator.aqasm_path), 'r') as fi:
            self.translation_line_list = fi.readlines()

        # Execute the generated "pg += ..." lines in an explicit namespace
        # and read pg back from it. Relying on exec() to update a
        # function's local variables is unsupported, so this works
        # whether or not `+=` mutates pg in place. Gate names (H, RX,
        # CNOT, ...) are resolved through this module's globals. The
        # lines come from our own translator's output file, not from
        # external input.
        namespace = {'pg': Program()}
        for line in self.translation_line_list:
            line = line.strip()
            if line:
                exec(line, globals(), namespace)
        return namespace['pg']

    def _get_effective_st_vec(self, pg):
        """
        Runs the program pg and returns the resulting (effective) state
        vector.

        If self.num_samples is nonzero, pg is run on self.qc that many
        times and an empirical state vector is built from the measured
        bitstrings. If self.num_samples is 0, the exact wavefunction is
        obtained from a PyQuil WavefunctionSimulator.

        Parameters
        ----------
        pg : Program

        Returns
        -------
        StateVec

        """
        if self.num_samples:
            # send and receive from cloud, get obs_vec
            bitstrings = self.qc.run_and_measure(pg,
                                                 trials=self.num_samples)
            obs_vec = RigettiTools.obs_vec_from_bitstrings(
                bitstrings, self.num_qbits, bs_is_array=False)

            # go from obs_vec to effective state vec
            counts_dict = StateVec.get_counts_from_obs_vec(
                self.num_qbits, obs_vec)
            emp_pd = StateVec.get_empirical_pd_from_counts(
                self.num_qbits, counts_dict)
            return StateVec.get_emp_state_vec_from_emp_pd(
                self.num_qbits, emp_pd)

        # num_samples == 0: exact simulation
        sim = WavefunctionSimulator()
        st_vec_arr = sim.wavefunction(pg).amplitudes
        st_vec_arr = st_vec_arr.reshape([2]*self.num_qbits)
        # reverse the axis order; presumably this converts PyQuil's qubit
        # ordering into Qubiter's StateVec ordering -- TODO confirm
        perm = list(reversed(range(self.num_qbits)))
        st_vec_arr = np.transpose(st_vec_arr, perm)
        return StateVec(self.num_qbits, st_vec_arr)

    def get_mean_val(self, var_num_to_rads):
        """
        This method returns a list partials_list consisting of 4 floats
        which are the partial derivatives wrt the 4 possible derivative
        directions (deriv_direc), of the multi-controlled gate U specified
        by self.deriv_gate_str.

        Each partial is accumulated over every derivative part of its
        direction, every term of self.hamil and (unless deriv_gate_str is
        'prior') both control polarities, with the negative polarity
        contributing with a minus sign.

        Parameters
        ----------
        var_num_to_rads : dict[int, float]

        Returns
        -------
        list[float]

        """
        partials_list = [0., 0., 0., 0.]
        # number of bits with (i.e., including) ancilla
        num_qbits_w_anc = self.num_qbits
        # the angle list of the gate being differentiated is the same for
        # every iteration of the loops below
        t_list = self.gate_str_to_rads_list[self.deriv_gate_str]
        for has_neg_polarity, deriv_direc in it.product(
                [False, True], range(4)):
            if self.deriv_gate_str == 'prior':
                # 'prior' gets a single pass per direction, made with
                # has_neg_polarity=None
                if not has_neg_polarity:
                    continue  # this skips iteration in loop
                has_neg_polarity = None
            for dpart_name in StairsDeriv.dpart_dict[deriv_direc]:
                emb = CktEmbedder(num_qbits_w_anc, num_qbits_w_anc)
                wr = StairsDerivCkt_writer(
                    self.deriv_gate_str, has_neg_polarity, deriv_direc,
                    dpart_name, self.gate_str_to_rads_list,
                    self.file_prefix, emb)
                wr.close_files()
                coef_of_dpart = StairsDerivCkt_writer.\
                    get_coef_of_dpart(t_list, deriv_direc,
                                      dpart_name, var_num_to_rads)
                fun_name_to_fun = StairsDerivCkt_writer.\
                    get_fun_name_to_fun(t_list, deriv_direc, dpart_name)

                vman = PlaceholderManager(
                        var_num_to_rads=var_num_to_rads,
                        fun_name_to_fun=fun_name_to_fun)

                # Program holding all gates that do not depend on hamil
                base_pg = self._get_program_from_file(num_qbits_w_anc, vman)
                len_base_pg = len(base_pg)
                for term, coef in self.hamil.terms.items():
                    # we have checked before that coef is real
                    coef = complex(coef).real

                    # measure X at the ancilla in addition to the term
                    new_term = tuple(list(term) + [(num_qbits_w_anc-1, 'X')])

                    # Start from a fresh copy of the hamil-independent
                    # gates so codas of previous terms are not kept.
                    # Remember bug in Pyquil. Slicing a program turns it
                    # into a list
                    pg = Program(base_pg[:len_base_pg])

                    # add measurement coda for this term of hamil
                    # and for X at ancilla
                    bit_pos_to_xy_str =\
                        {bit: action for bit, action in new_term
                         if action != 'Z'}
                    RigettiTools.add_xy_meas_coda_to_program(
                        pg, bit_pos_to_xy_str)

                    effective_st_vec = self._get_effective_st_vec(pg)

                    # add contribution to mean
                    real_arr = self.get_real_vec(new_term)
                    mean_val_change = coef*effective_st_vec.\
                        get_mean_value_of_real_diag_mat(real_arr)
                    mean_val_change *= coef_of_dpart
                    if has_neg_polarity:
                        mean_val_change *= -1
                    partials_list[deriv_direc] += mean_val_change
        return partials_list


if __name__ == "__main__":
    def main():
        """Placeholder script entry point; prints the number 5."""
        placeholder_value = 5
        print(placeholder_value)

    main()
